// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>
#include "single_layer_tests/experimental_detectron_generate_proposals_single_image.hpp"
#include <common_test_utils/ov_tensor_utils.hpp>

using namespace ov::test;
using namespace ov::test::subgraph;

namespace {

// Single-value attribute sweeps consumed via ::testing::ValuesIn by the suite instantiation;
// presumably forwarded to the op attributes of the same names.
const std::vector<float> min_size{0.0f};
// 0.7f is bit-identical to the float value 0.699999988079071.
const std::vector<float> nms_threshold{0.7f};
const std::vector<int64_t> post_nms_count{6};
const std::vector<int64_t> pre_nms_count{1000};

// Named input data cases. Each case lists four f32 tensors in the same order as the shape
// variants in dynamicInputShape: im_info {3}, anchors {36, 4}, deltas {12, 2, 6}, scores {3, 2, 6}.
const std::vector<std::pair<std::string, std::vector<ov::Tensor>>> inputTensors = {
    {
        // Degenerate data: im_info, anchors and deltas are all 1.0; scores are 1.0 except three peaks.
        "empty",
        {
            // im_info: 3 values
            ov::test::utils::create_tensor<float>(ov::element::f32, ov::Shape{3}, std::vector<float>(3, 1.0f)),
            // anchors: 36 x 4 = 144 values
            ov::test::utils::create_tensor<float>(ov::element::f32, ov::Shape{36, 4},
                                                  std::vector<float>(36 * 4, 1.0f)),
            // deltas: 12 x 2 x 6 = 144 values
            ov::test::utils::create_tensor<float>(ov::element::f32, ov::Shape{12, 2, 6},
                                                  std::vector<float>(12 * 2 * 6, 1.0f)),
            // scores: 3 x 2 x 6 = 36 values (flat indices 0, 20 and 34 hold 5.0, 4.0 and 8.0)
            ov::test::utils::create_tensor<float>(ov::element::f32, ov::Shape{3, 2, 6}, {
                5.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
                1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 4.0f, 1.0f, 1.0f, 1.0f,
                1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 8.0f, 1.0f})
        }
    },
    {
        // Realistic data: 150 x 150 image info, varied anchors, deltas and scores.
        "filled",
        {
            // im_info: 3 values
            ov::test::utils::create_tensor<float>(ov::element::f32, ov::Shape{3}, {150.0f, 150.0f, 1.0f}),
            // anchors: 36 x 4 = 144 values
            ov::test::utils::create_tensor<float>(ov::element::f32, ov::Shape{36, 4}, {
                12.0, 68.0, 102.0, 123.0, 46.0, 80.0,  79.0,  128.0, 33.0, 71.0, 127.0, 86.0,  33.0, 56.0, 150.0, 73.0,
                5.0,  41.0, 93.0,  150.0, 74.0, 66.0,  106.0, 115.0, 17.0, 37.0, 87.0,  150.0, 31.0, 27.0, 150.0, 39.0,
                29.0, 23.0, 112.0, 123.0, 41.0, 37.0,  103.0, 150.0, 8.0,  46.0, 98.0,  111.0, 7.0,  69.0, 114.0, 150.0,
                70.0, 21.0, 150.0, 125.0, 54.0, 19.0,  132.0, 68.0,  62.0, 8.0,  150.0, 101.0, 57.0, 81.0, 150.0, 97.0,
                79.0, 29.0, 109.0, 130.0, 12.0, 63.0,  100.0, 150.0, 17.0, 33.0, 113.0, 150.0, 90.0, 78.0, 150.0, 111.0,
                47.0, 68.0, 150.0, 71.0,  66.0, 103.0, 111.0, 150.0, 4.0,  17.0, 112.0, 94.0,  12.0, 8.0,  119.0, 98.0,
                54.0, 56.0, 120.0, 150.0, 56.0, 29.0,  150.0, 31.0,  42.0, 3.0,  139.0, 92.0,  41.0, 65.0, 150.0, 130.0,
                49.0, 13.0, 143.0, 30.0,  40.0, 60.0,  150.0, 150.0, 23.0, 73.0, 24.0,  115.0, 56.0, 84.0, 107.0, 108.0,
                63.0, 8.0,  142.0, 125.0, 78.0, 37.0,  93.0,  144.0, 40.0, 34.0, 150.0, 46.0,  30.0, 21.0, 150.0, 120.0}),
            // deltas: 12 x 2 x 6 = 144 values
            ov::test::utils::create_tensor<float>(ov::element::f32, ov::Shape{12, 2, 6}, {
                9.062256,   10.883133, 9.8441105,   12.694285,   0.41781136, 8.749107,      14.990341,  6.587644,  1.4206103,
                13.299262,  12.432549, 2.736371,    0.22732796,  6.3361835,  12.268727,     2.1009045,  4.771589,  2.5131326,
                5.610736,   9.3604145, 4.27379,     8.317948,    0.60510135, 6.7446275,     1.0207708,  1.1352817, 1.5785321,
                1.718335,   1.8093798, 0.99247587,  1.3233583,   1.7432803,  1.8534478,     1.2593061,  1.7394226, 1.7686696,
                1.647999,   1.7611449, 1.3119122,   0.03007332,  1.1106564,  0.55669737,    0.2546148,  1.9181818, 0.7134989,
                2.0407224,  1.7211134, 1.8565536,   14.562747,   2.8786168,  0.5927796,     0.2064463,  7.6794515, 8.672126,
                10.139171,  8.002429,  7.002932,    12.6314945,  10.550842,  0.15784842,    0.3194304,  10.752157, 3.709805,
                11.628928,  0.7136225, 14.619964,   15.177284,   2.2824087,  15.381494,     0.16618137, 7.507227,  11.173228,
                0.4923559,  1.8227729, 1.4749299,   1.7833921,   1.2363617,  -0.23659119,   1.5737582,  1.779316,  1.9828427,
                1.0482665,  1.4900246, 1.3563544,   1.5341306,   0.7634312,  4.6216766e-05, 1.6161222,  1.7512476, 1.9363779,
                0.9195784,  1.4906164, -0.03244795, 0.681073,    0.6192401,  1.8033613,     14.146055,  3.4043705, 15.292292,
                3.5295358,  11.138999, 9.952057,    5.633434,    12.114562,  9.427372,      12.384038,  9.583308,  8.427233,
                15.293704,  3.288159,  11.64898,    9.350885,    2.0037227,  13.523184,     4.4176426,  6.1057625, 14.400079,
                8.248259,   11.815807, 15.713364,   1.0023532,   1.3203261,  1.7100681,     0.7407832,  1.09448,   1.7188418,
                1.4412547,  1.4862992, 0.74790007,  0.31571656,  0.6398838,  2.0236106,     1.1869069,  1.7265586, 1.2624544,
                0.09934269, 1.3508598, 0.85212964,  -0.38968498, 1.7059708,  1.6533034,     1.7400402,  1.8123854, -0.43063712}),
            // scores: 3 x 2 x 6 = 36 values
            ov::test::utils::create_tensor<float>(ov::element::f32, ov::Shape{3, 2, 6}, {
                0.7719922,  0.35906568, 0.29054508, 0.18124384, 0.5604661,  0.84750974, 0.98948747, 0.009793862, 0.7184191,
                0.5560748,  0.6952493,  0.6732593,  0.3306898,  0.6790913,  0.41128764, 0.34593266, 0.94296855,  0.7348507,
                0.24478768, 0.94024557, 0.05405676, 0.06466125, 0.36244348, 0.07942984, 0.10619422, 0.09412837,  0.9053611,
                0.22870538, 0.9237487,  0.20986171, 0.5067282,  0.29709867, 0.53138554, 0.189101,   0.4786443,   0.88421875})
        }
    }
};

// Input shape variants. Within each variant the four entries are, in order:
// im_info / anchors / deltas / scores. Every variant resolves to the same static target shapes
// ({3}, {36, 4}, {12, 2, 6}, {3, 2, 6}), matching the data cases in inputTensors.
const std::vector<std::vector<InputShape>> dynamicInputShape = {
    // Variant 0: fully static inputs.
    static_shapes_to_test_representation({{3}, {36, 4}, {12, 2, 6}, {3, 2, 6}}),
    // Variant 1: ranks fixed, every dimension dynamic (-1).
    {{{-1}, {{3}}},
     {{-1, -1}, {{36, 4}}},
     {{-1, -1, -1}, {{12, 2, 6}}},
     {{-1, -1, -1}, {{3, 2, 6}}}},
    // Variant 2: ranks fixed, every dimension bounded by a {lower, upper} interval.
    {{{{3, 6}}, {{3}}},
     {{{36, 72}, {4, 8}}, {{36, 4}}},
     {{{12, 24}, {2, 4}, {6, 12}}, {{12, 2, 6}}},
     {{{3, 6}, {2, 4}, {6, 12}}, {{3, 2, 6}}}}
};

// Registers the smoke suite on CPU over the cartesian product of the tables above:
// 3 shape variants x 1 min_size x 1 nms_threshold x 1 post_nms_count x 1 pre_nms_count
// x 2 data cases x 1 element type x 1 device = 6 test instances.
// NOTE(review): the argument order of ::testing::Combine must match the parameter tuple declared by
// ExperimentalDetectronGenerateProposalsSingleImageLayerTest in the shared test header.
INSTANTIATE_TEST_SUITE_P(
    smoke_ExperimentalDetectronGenerateProposalsSingleImageLayerTest,
    ExperimentalDetectronGenerateProposalsSingleImageLayerTest,
    ::testing::Combine(
        ::testing::ValuesIn(dynamicInputShape),
        ::testing::ValuesIn(min_size),
        ::testing::ValuesIn(nms_threshold),
        ::testing::ValuesIn(post_nms_count),
        ::testing::ValuesIn(pre_nms_count),
        ::testing::ValuesIn(inputTensors),
        // Element type parameter; presumably the model/inference precision — confirm in the fixture.
        ::testing::Values(ov::element::Type_t::f32),
        // Target device for the test run.
        ::testing::Values(CommonTestUtils::DEVICE_CPU)),
    ExperimentalDetectronGenerateProposalsSingleImageLayerTest::getTestCaseName);
} // namespace
